In this tutorial, we give an introduction to neural ordinary differential equations (NODEs) [chen2018neural]. A one-sentence summary of this model family is: NODEs parameterize the time derivative of a hidden state with a neural network and compute outputs by solving the resulting ODE.
We will start this tutorial with a discussion of ODEs. Instead of presenting technical details, we will give a practical introduction. Next, we formally describe NODEs and show three standard use cases: classification, normalizing flows, and latent dynamics learning. We will close with works that study different aspects of the vanilla NODE.
NOTE: Most of the code pieces in this tutorial, as well as the figures, are from the original neural ODE paper and the corresponding GitHub repo.
The following cell imports all the required libraries.
%load_ext autoreload
%autoreload 2
!pip install torch torchvision torchdiffeq numpy scipy matplotlib pillow scikit-learn
import numpy as np
from IPython import display
import time
from sklearn.datasets import make_circles
import torch
import torch.nn as nn
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from torchdiffeq import odeint
from bnn import BNN
from vae_utils import MNIST_Encoder, MNIST_Decoder
from plot_utils import plot_vdp_trajectories, plot_ode, plot_vdp_animation, plot_cnf_animation, \
plot_mnist_sequences, plot_mnist_predictions, plot_cnf_data
from utils import get_minibatch, mnist_loaders, inf_generator, mnist_accuracy, \
count_parameters, conv3x3, group_norm, Flatten, load_rotating_mnist
Ordinary differential equations relate functions of an independent variable to the derivatives of those functions. Formally,
\begin{equation} \dot{\mathbf{x}}(t) = \frac{d\mathbf{x}(t)}{dt} = \lim_{\Delta t \rightarrow 0} \frac{ \mathbf{x}(t + \Delta t) - \mathbf{x}(t)}{\Delta t} = \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t), \end{equation} where $\mathbf{x}(t)$ is the state, $\mathbf{u}(t)$ is an (optional) control input, and $t$ is the independent (time) variable.
Informally speaking, $\mathbf{f}$ tells "how much the state $\mathbf{x}(t)$ would change with an infinitesimal change in $t$". More formally, the following equation holds in the limit $\Delta t \rightarrow 0$: \begin{equation} \mathbf{x}(t+\Delta t) = \mathbf{x}(t) + \Delta t \cdot \mathbf{f}(\mathbf{x}(t),\mathbf{u}(t),t). \end{equation}
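As a quick sanity check (not part of the tutorial code), the following plain-Python sketch applies one such step to the hypothetical scalar field $f(x) = -x$, whose exact solution is $x(t) = x_0 e^{-t}$, and shows the approximation error shrinking as $\Delta t$ shrinks:

```python
import math

def f(x):
    # hypothetical right-hand side: exponential decay, dx/dt = -x
    return -x

x0 = 1.0
errs = []
for dt in (0.1, 0.01, 0.001):
    x_step = x0 + dt * f(x0)      # x(dt) ~ x(0) + dt * f(x(0))
    x_exact = x0 * math.exp(-dt)  # exact solution at time dt
    errs.append(abs(x_step - x_exact))
print(errs)  # the error shrinks roughly quadratically in dt
```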
Note-1: We often refer to $\mathbf{f}$ as the vector field or the right-hand side.
Note-2: The above problem is also known as an initial value problem.
Note-3: Throughout this tutorial, we focus on differential functions $\mathbf{f}(\mathbf{x}(t))$ that are independent of control signals and not explicitly parameterized by time (i.e., autonomous systems).
An "ODE state solution" $\mathbf{x}(t)$ at time $t\in \mathbb{R}_+$ is given by \begin{equation} \mathbf{x}(t) = \mathbf{x}_0 + \int_0^t \mathbf{f}(\mathbf{x}_\tau)~d\tau, \end{equation} where $\mathbf{x}_0$ denotes the initial value and $\tau \in \mathbb{R}_+$ is an auxiliary time variable.
Note-1: Given an initial value $\mathbf{x}_0$ and a set of time points $\{t_0,t_1,\ldots,t_N\}$, we are often interested in state solutions $\mathbf{x}_{0:N}\equiv\{\mathbf{x}(t_0),\mathbf{x}(t_1),\ldots,\mathbf{x}(t_N)\}$
Note-2: We occasionally denote $\mathbf{x}_n \equiv \mathbf{x}(t_n)$.
Note-3: The above integral has a tractable form only for very simple differential functions (recall the integration rules from high school). Therefore, we almost always resort to numerical solvers.
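As one of the rare tractable cases, consider the linear differential function $\mathbf{f}(x) = a x$ with $a \in \mathbb{R}$; its state solution can be computed in closed form: \begin{equation} x(t) = x_0 + \int_0^t a\,x(\tau)~d\tau = x_0\,e^{at}. \end{equation} For almost any nonlinear $\mathbf{f}$, no such closed form exists.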
Numerical solvers: TL;DR: A state solution $\mathbf{x}(t)$ can be numerically computed up to a tolerable error.
The celebrated Picard existence and uniqueness theorem states that an initial value problem has a unique solution if the differential function is Lipschitz continuous. Despite the uniqueness guarantee, there is no general recipe to analytically compute the solution; therefore, we often resort to numerical methods. The simplest and least efficient numerical method is Euler's method (the update equation above). More advanced methods such as Heun's method and the Runge-Kutta family of solvers compute average slopes by evaluating $\mathbf{f}(\mathbf{x}(t))$ at multiple locations (a speed vs. accuracy trade-off). Even more advanced adaptive-step solvers set the step size $\Delta t$ dynamically.
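To make the speed vs. accuracy trade-off concrete, here is a minimal plain-Python sketch (again on the hypothetical scalar ODE $\dot{x} = -x$, not part of the tutorial code) comparing Euler's method with classical fourth-order Runge-Kutta (RK4) at the same step size; RK4 evaluates $f$ four times per step and averages the slopes:

```python
import math

def f(x):
    # hypothetical right-hand side: dx/dt = -x, exact solution x(t) = x0 * exp(-t)
    return -x

def euler_step(x, dt):
    # one slope evaluation per step
    return x + dt * f(x)

def rk4_step(x, dt):
    # four slope evaluations per step, combined as a weighted average
    k1 = f(x)
    k2 = f(x + 0.5 * dt * k1)
    k3 = f(x + 0.5 * dt * k2)
    k4 = f(x + dt * k3)
    return x + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6.0

def integrate(step, x0, t_end, n_steps):
    x, dt = x0, t_end / n_steps
    for _ in range(n_steps):
        x = step(x, dt)
    return x

exact = math.exp(-2.0)  # x(2) for x0 = 1
err_euler = abs(integrate(euler_step, 1.0, 2.0, 20) - exact)
err_rk4 = abs(integrate(rk4_step, 1.0, 2.0, 20) - exact)
print(err_euler, err_rk4)  # RK4 is far more accurate at the same step size
```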
In this tutorial, we use the torchdiffeq library, which implements the adjoint method for gradient estimation.
As an example, we examine the Van der Pol (VDP) oscillator, a parametric $2D$ time-invariant ODE system that evolves according to the following: \begin{equation} \label{eq:vdp} \frac{d}{dt} \begin{bmatrix} x_1 \\ x_2 \end{bmatrix} = \begin{bmatrix} x_2 \\ \mu(1-x_1^2)x_2-x_1 \end{bmatrix}. \end{equation}
Our VDP implementation below follows the two requirements of torchdiffeq: the differential function must be implemented as an nn.Module, and its forward() function must take a (time, state) pair as input.
# define the differential function
class VDP(nn.Module):
    def __init__(self, mu):
        ''' mu is the only parameter in the VDP oscillator '''
        super().__init__()
        self.mu = mu

    def forward(self, t, x):
        ''' Implements the right hand side
            Inputs
                t - [] time
                x - [N,d] state(s)
            Output
                \dot{x} - [N,d], time derivative
        '''
        d1 = x[...,1:2]
        d2 = self.mu*(1-x[...,0:1]**2)*x[...,1:2] - x[...,0:1]
        return torch.cat([d1,d2],-1)
Next, we instantiate the three ingredients (differential function $\mathbf{f}$, initial value $\mathbf{x}_0$, integration time points $t$), forward integrate, and visualize how integration proceeds.
# create the differential function, needs to be a nn.Module
vdp = VDP(1.0).to(device)
# initial value, of shape [N,n]
x0 = torch.tensor([[1.0,0.0]]).float().to(device)
# integration time points, of shape [T]
ts = torch.linspace(0., 15., 500).to(device)
# forward integration
with torch.no_grad():
    X = odeint(vdp, x0, ts) # [T,N,n]
# animation
anim = plot_vdp_animation(ts,X,vdp)
display.HTML(anim.to_jshtml())